import torch
import torch.nn as nn
from torch_npu.contrib import transfer_to_npu


# Report what the NPU backend exposes before choosing a device.
# NOTE(review): the torch.npu namespace is presumably registered by the
# torch_npu import at the top of the file; confirm it is installed.
print("NPU available:", torch.npu.is_available())
print("NPU device count:", torch.npu.device_count())

# Prefer the first NPU; otherwise fall back to the CPU.
if torch.npu.is_available():
    device = torch.device("npu:0")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Smoke test: a single forward pass of a 10 -> 1 linear layer over a
# random batch of 5 samples, with both layer and input moved to `device`.
model = nn.Linear(10, 1)
model = model.to(device)

x = torch.randn(5, 10)
x = x.to(device)

y = model(x)

print("Output:", y)